{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f4b72fd7",
   "metadata": {},
   "source": [
    "## Description:\n",
    "这是 MMOE（Multi-gate Mixture-of-Experts）多任务模型的 demo：同时训练 duration（回归任务）和 click（二分类任务）两个目标。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "324f6ca1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "\n",
    "from sklearn.preprocessing import MinMaxScaler, StandardScaler\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "from sklearn.metrics import roc_auc_score, mean_squared_error, mean_absolute_error\n",
    "from deepctr.feature_column import SparseFeat, VarLenSparseFeat, DenseFeat\n",
    "from deepctr.feature_column import get_feature_names\n",
    "\n",
    "from MMOE import MMOE\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "577efc09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory that holds train_data.csv, relative to this notebook.\n",
    "data_path = '../data_process'\n",
    "train_csv = os.path.join(data_path, 'train_data.csv')\n",
    "\n",
    "# The first CSV column becomes the index; 'expo_time' is parsed as datetime.\n",
    "data = pd.read_csv(\n",
    "    train_csv,\n",
    "    index_col=0,\n",
    "    parse_dates=['expo_time'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "88a84ee4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the columns this demo needs.\n",
    "use_cols = [\n",
    "    'user_id', 'article_id', 'expo_time', 'net_status', 'exop_position',\n",
    "    'duration', 'device', 'city', 'age', 'gender', 'img_num', 'cat_1', 'click',\n",
    "]\n",
    "data_new = data[use_cols]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ce2ea472",
   "metadata": {},
   "outputs": [],
   "source": [
    "# data_new is still too large to train on a local machine, so keep only the rows\n",
    "# that belong to (at most) 1000 randomly chosen users.\n",
    "users = set(data_new['user_id'])\n",
    "\n",
    "# A dedicated, seeded generator makes the draw repeatable on Restart & Run All\n",
    "# without touching the global `random` state.\n",
    "rng = random.Random(2021)\n",
    "\n",
    "# random.sample() needs a sequence: passing a set is deprecated since Python 3.9\n",
    "# and raises TypeError from 3.11. Sorting also fixes the population order, and\n",
    "# min() avoids a ValueError when fewer than 1000 users are present.\n",
    "sampled_users = rng.sample(sorted(users), min(1000, len(users)))\n",
    "\n",
    "# .copy() yields an independent frame, so later column assignments do not write\n",
    "# into a slice of the frame selected from `data`.\n",
    "data_new = data_new[data_new['user_id'].isin(sampled_users)].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "84eb94aa",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>article_id</th>\n",
       "      <th>expo_time</th>\n",
       "      <th>net_status</th>\n",
       "      <th>exop_position</th>\n",
       "      <th>duration</th>\n",
       "      <th>device</th>\n",
       "      <th>city</th>\n",
       "      <th>age</th>\n",
       "      <th>gender</th>\n",
       "      <th>img_num</th>\n",
       "      <th>cat_1</th>\n",
       "      <th>click</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>10661</th>\n",
       "      <td>60</td>\n",
       "      <td>2174</td>\n",
       "      <td>2021-06-30 13:36:57</td>\n",
       "      <td>0</td>\n",
       "      <td>17</td>\n",
       "      <td>0</td>\n",
       "      <td>174</td>\n",
       "      <td>237</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>15</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10662</th>\n",
       "      <td>60</td>\n",
       "      <td>4458</td>\n",
       "      <td>2021-06-30 13:36:57</td>\n",
       "      <td>0</td>\n",
       "      <td>21</td>\n",
       "      <td>0</td>\n",
       "      <td>174</td>\n",
       "      <td>237</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.033149</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10663</th>\n",
       "      <td>60</td>\n",
       "      <td>4037</td>\n",
       "      <td>2021-06-30 13:40:23</td>\n",
       "      <td>0</td>\n",
       "      <td>24</td>\n",
       "      <td>0</td>\n",
       "      <td>174</td>\n",
       "      <td>237</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.033149</td>\n",
       "      <td>12</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10664</th>\n",
       "      <td>60</td>\n",
       "      <td>3109</td>\n",
       "      <td>2021-06-30 13:36:57</td>\n",
       "      <td>0</td>\n",
       "      <td>14</td>\n",
       "      <td>0</td>\n",
       "      <td>174</td>\n",
       "      <td>237</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.038674</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10665</th>\n",
       "      <td>60</td>\n",
       "      <td>14125</td>\n",
       "      <td>2021-07-03 06:10:46</td>\n",
       "      <td>0</td>\n",
       "      <td>7</td>\n",
       "      <td>0</td>\n",
       "      <td>174</td>\n",
       "      <td>237</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.027624</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       user_id  article_id           expo_time  net_status  exop_position  \\\n",
       "10661       60        2174 2021-06-30 13:36:57           0             17   \n",
       "10662       60        4458 2021-06-30 13:36:57           0             21   \n",
       "10663       60        4037 2021-06-30 13:40:23           0             24   \n",
       "10664       60        3109 2021-06-30 13:36:57           0             14   \n",
       "10665       60       14125 2021-07-03 06:10:46           0              7   \n",
       "\n",
       "       duration  device  city  age  gender   img_num  cat_1  click  \n",
       "10661         0     174   237    1       1  0.000000     15      0  \n",
       "10662         0     174   237    1       1  0.033149      1      0  \n",
       "10663         0     174   237    1       1  0.033149     12      0  \n",
       "10664         0     174   237    1       1  0.038674     13      0  \n",
       "10665         0     174   237    1       1  0.027624     13      0  "
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first rows of data_new.\n",
    "data_new.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df0b628c",
   "metadata": {},
   "source": [
    "## 数据预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e380f114",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean up the img_num column.\n",
    "def transform(x):\n",
    "    \"\"\"Convert one raw `img_num` entry to a float.\n",
    "\n",
    "    Args:\n",
    "        x: Raw cell value: a number, or a string holding a numeric literal.\n",
    "\n",
    "    Returns:\n",
    "        0.0 for the non-numeric entry '上海', otherwise `x` as a float.\n",
    "\n",
    "    Raises:\n",
    "        ValueError, SyntaxError: if a string is not a valid Python literal.\n",
    "    \"\"\"\n",
    "    import ast  # local import keeps this cell self-contained\n",
    "\n",
    "    if isinstance(x, str):\n",
    "        if x == '上海':\n",
    "            return 0.0\n",
    "        # literal_eval only parses literals, so text coming from the CSV can\n",
    "        # never execute code the way eval() could.\n",
    "        return float(ast.literal_eval(x.strip()))\n",
    "    # Numeric input (int or float). Integers previously fell through to\n",
    "    # eval(), which rejects non-string arguments with a TypeError.\n",
    "    return float(x)\n",
    "\n",
    "\n",
    "data_new['img_num'] = data_new['img_num'].apply(transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "46f77eb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct raw ids as they are before label encoding.\n",
    "user_id_raw = data_new[['user_id']].drop_duplicates('user_id')\n",
    "doc_id_raw = data_new[['article_id']].drop_duplicates('article_id')\n",
    "\n",
    "# Basic preprocessing: categorical (sparse) vs. numeric (dense) columns.\n",
    "sparse_features = [\n",
    "    'user_id', 'article_id', 'net_status', 'exop_position', 'device', 'city', 'age', 'gender', 'cat_1'\n",
    "]\n",
    "dense_features = [\n",
    "    'img_num'\n",
    "]\n",
    "\n",
    "# Work on an independent frame so the assignments below never write into a\n",
    "# slice of another DataFrame.\n",
    "data_new = data_new.copy()\n",
    "\n",
    "# Fill missing values. Sparse columns are cast to str afterwards: mixing the\n",
    "# '-1' placeholder with numeric values would hand LabelEncoder a mixed-type\n",
    "# column that it cannot sort.\n",
    "data_new[sparse_features] = data_new[sparse_features].fillna('-1').astype(str)\n",
    "data_new[dense_features] = data_new[dense_features].fillna(0)\n",
    "\n",
    "# Scale dense features to [0, 1].\n",
    "# NOTE(review): the scaler is fitted on all rows, including those that the\n",
    "# time-based split later assigns to the test set.\n",
    "mms = MinMaxScaler(feature_range=(0, 1))\n",
    "data_new[dense_features] = mms.fit_transform(data_new[dense_features])\n",
    "\n",
    "# Label-encode every sparse feature, keep its encoder, and record the number\n",
    "# of distinct classes (largest encoded index + 1).\n",
    "feature_max_idx = {}\n",
    "label_encoders = {}\n",
    "for feat in sparse_features:\n",
    "    lbe = LabelEncoder()\n",
    "    data_new[feat] = lbe.fit_transform(data_new[feat])\n",
    "    label_encoders[feat] = lbe\n",
    "    feature_max_idx[feat] = int(len(lbe.classes_))\n",
    "\n",
    "# Encoded index -> raw id (as a string, because of the cast above), so a\n",
    "# prediction can be traced back to the original user / article.\n",
    "user_idx_2_rawid = dict(enumerate(label_encoders['user_id'].classes_))\n",
    "doc_idx_2_rawid = dict(enumerate(label_encoders['article_id'].classes_))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ad2691ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split by exposure time: rows before the cut-off date form the training set,\n",
    "# rows on or after it form the test set.\n",
    "split_date = '2021-07-06'\n",
    "expo_time = data_new['expo_time']\n",
    "train_data = data_new[expo_time < split_date]\n",
    "test_data = data_new[expo_time >= split_date]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad199918",
   "metadata": {},
   "source": [
    "## 特征封装"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d0e5ea3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One SparseFeat per categorical column (sized from feature_max_idx, 4-dim\n",
    "# embedding) followed by one DenseFeat of dimension 1 per numeric column.\n",
    "sparse_feature_columns = [\n",
    "    SparseFeat(name, feature_max_idx[name], embedding_dim=4)\n",
    "    for name in sparse_features\n",
    "]\n",
    "dense_feature_columns = [DenseFeat(name, 1) for name in dense_features]\n",
    "fixlen_feature_columns = sparse_feature_columns + dense_feature_columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c9f538ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature columns for the DNN part of the model. Every fixed-length column is\n",
    "# used as-is; no separate linear feature set is built in this cell.\n",
    "dnn_features_columns = fixlen_feature_columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f5355501",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input names that DeepCTR reports for these feature columns; they are used as\n",
    "# the keys of the model-input dicts.\n",
    "feature_names = get_feature_names(dnn_features_columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "c6a5f286",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feeding the raw columns directly failed with\n",
    "#   AttributeError: 'numpy.dtype[int64]' object has no attribute 'base_dtype'\n",
    "# Keras needs its inputs declared as Keras tensors (NumPy-backed data did not\n",
    "# work here), so each column is wrapped as a backend constant first.\n",
    "def _as_keras_inputs(frame):\n",
    "    \"\"\"Map every name in `feature_names` to a Keras constant built from `frame[name]`.\"\"\"\n",
    "    return {name: tf.keras.backend.constant(frame[name]) for name in feature_names}\n",
    "\n",
    "\n",
    "train_model_input = _as_keras_inputs(train_data)\n",
    "test_model_input = _as_keras_inputs(test_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e4803b7",
   "metadata": {},
   "source": [
    "## 模型训练和预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "39172a8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two tasks share one MMOE model: 'duration' (regression) and 'click' (binary).\n",
    "task_names = ['duration', 'click']\n",
    "task_types = ['regression', 'binary']\n",
    "model = MMOE(\n",
    "    dnn_features_columns,\n",
    "    tower_dnn_hidden_units=[],\n",
    "    task_types=task_types,\n",
    "    task_names=task_names,\n",
    ")\n",
    "\n",
    "# Per-task loss, loss weight and reported metric, all keyed by task name.\n",
    "task_losses = {'duration': 'mean_squared_error', 'click': 'binary_crossentropy'}\n",
    "task_loss_weights = {'duration': 0.02, 'click': 0.98}\n",
    "task_metrics = {'duration': 'mae', 'click': 'binary_crossentropy'}\n",
    "model.compile(\n",
    "    'adam',\n",
    "    loss=task_losses,\n",
    "    loss_weights=task_loss_weights,\n",
    "    metrics=task_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "051a9186",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training targets for the two tasks, each wrapped as a Keras backend constant.\n",
    "label_duration, label_click = (\n",
    "    tf.keras.backend.constant(train_data[target].values)\n",
    "    for target in ('duration', 'click')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "c32ea007",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 123209 samples, validate on 30803 samples\n",
      "Epoch 1/10\n",
      "123209/123209 [==============================] - 12s 99us/sample - loss: 116.1293 - duration_loss: 5786.4404 - click_loss: 0.3734 - duration_mae: 33.8007 - click_binary_crossentropy: 0.3734 - val_loss: 166.6962 - val_duration_loss: 8344.8467 - val_click_loss: 0.4353 - val_duration_mae: 58.8950 - val_click_binary_crossentropy: 0.4350\n",
      "Epoch 2/10\n",
      "123209/123209 [==============================] - 12s 97us/sample - loss: 109.4779 - duration_loss: 5454.9507 - click_loss: 0.3595 - duration_mae: 31.8456 - click_binary_crossentropy: 0.3595 - val_loss: 166.9437 - val_duration_loss: 8357.1445 - val_click_loss: 0.4568 - val_duration_mae: 54.7640 - val_click_binary_crossentropy: 0.4565\n",
      "Epoch 3/10\n",
      "123209/123209 [==============================] - 12s 97us/sample - loss: 103.7577 - duration_loss: 5169.7749 - click_loss: 0.3481 - duration_mae: 30.4120 - click_binary_crossentropy: 0.3481 - val_loss: 177.6279 - val_duration_loss: 8891.0713 - val_click_loss: 0.4713 - val_duration_mae: 56.4576 - val_click_binary_crossentropy: 0.4709\n",
      "Epoch 4/10\n",
      "123209/123209 [==============================] - 12s 97us/sample - loss: 99.9210 - duration_loss: 4979.6274 - click_loss: 0.3426 - duration_mae: 29.4238 - click_binary_crossentropy: 0.3425 - val_loss: 184.7347 - val_duration_loss: 9249.1885 - val_click_loss: 0.4250 - val_duration_mae: 57.1670 - val_click_binary_crossentropy: 0.4246\n",
      "Epoch 5/10\n",
      "123209/123209 [==============================] - 12s 97us/sample - loss: 96.8891 - duration_loss: 4828.1299 - click_loss: 0.3378 - duration_mae: 28.7680 - click_binary_crossentropy: 0.3378 - val_loss: 194.2312 - val_duration_loss: 9722.2783 - val_click_loss: 0.4478 - val_duration_mae: 57.8451 - val_click_binary_crossentropy: 0.4475\n",
      "Epoch 6/10\n",
      "123209/123209 [==============================] - 12s 98us/sample - loss: 94.7821 - duration_loss: 4721.8525 - click_loss: 0.3330 - duration_mae: 28.2337 - click_binary_crossentropy: 0.3330 - val_loss: 203.2876 - val_duration_loss: 10170.6797 - val_click_loss: 0.5447 - val_duration_mae: 58.5539 - val_click_binary_crossentropy: 0.5443\n",
      "Epoch 7/10\n",
      "123209/123209 [==============================] - 12s 97us/sample - loss: 93.0133 - duration_loss: 4633.9004 - click_loss: 0.3281 - duration_mae: 27.8672 - click_binary_crossentropy: 0.3281 - val_loss: 205.4053 - val_duration_loss: 10276.6504 - val_click_loss: 0.5332 - val_duration_mae: 57.2810 - val_click_binary_crossentropy: 0.5328\n",
      "Epoch 8/10\n",
      "123209/123209 [==============================] - 12s 97us/sample - loss: 91.4517 - duration_loss: 4556.3979 - click_loss: 0.3228 - duration_mae: 27.4333 - click_binary_crossentropy: 0.3228 - val_loss: 211.3673 - val_duration_loss: 10577.6230 - val_click_loss: 0.4885 - val_duration_mae: 57.6894 - val_click_binary_crossentropy: 0.4880\n",
      "Epoch 9/10\n",
      "123209/123209 [==============================] - 12s 98us/sample - loss: 90.1165 - duration_loss: 4490.1094 - click_loss: 0.3182 - duration_mae: 27.0779 - click_binary_crossentropy: 0.3182 - val_loss: 208.9492 - val_duration_loss: 10456.8262 - val_click_loss: 0.4825 - val_duration_mae: 57.0025 - val_click_binary_crossentropy: 0.4820\n",
      "Epoch 10/10\n",
      "123209/123209 [==============================] - 12s 97us/sample - loss: 88.7979 - duration_loss: 4428.1157 - click_loss: 0.3146 - duration_mae: 26.8201 - click_binary_crossentropy: 0.3146 - val_loss: 226.1269 - val_duration_loss: 11316.2451 - val_click_loss: 0.4480 - val_duration_mae: 60.3714 - val_click_binary_crossentropy: 0.4475\n"
     ]
    }
   ],
   "source": [
    "# Train both tasks jointly; validation_split=0.2 holds out 20% of the training\n",
    "# samples for validation. Labels are passed in the order duration, click.\n",
    "train_labels = [label_duration, label_click]\n",
    "history = model.fit(\n",
    "    train_model_input,\n",
    "    train_labels,\n",
    "    batch_size=128,\n",
    "    epochs=10,\n",
    "    verbose=1,\n",
    "    validation_split=0.2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "8d671e57",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict on the test inputs. The evaluation cells read pred_ans[0] as the\n",
    "# duration prediction and pred_ans[1] as the click prediction.\n",
    "pred_ans = model.predict(test_model_input, batch_size=256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "32dc4aa3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test click AUC 0.647\n"
     ]
    }
   ],
   "source": [
    "# ROC AUC of the click prediction (second model output) on the test split.\n",
    "click_auc = roc_auc_score(test_data['click'], pred_ans[1])\n",
    "print(\"test click AUC\", round(click_auc, 4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "6c377e25",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test duration 48.4138\n"
     ]
    }
   ],
   "source": [
    "# Mean absolute error of the duration prediction (first model output) on the test split.\n",
    "duration_mae = mean_absolute_error(test_data['duration'], pred_ans[0])\n",
    "print(\"test duration\", round(duration_mae, 4))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
